import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import networkx as nx
from scipy.sparse import coo_matrix
import numpy as np
import matplotlib.pyplot as plt

from physics_gatv2 import PhysicsAwareGATv2
from flow import *
from evaluation import *
from util import sparse_tensor_from_coo_matrix
from tikhonov import *
from mlp import *
from gcn import *
from bilevel_mlp import *
from bilevel_gcn import *

# ---------------------------------------------------------------------
# convenience: dense incidence matrix once per run
# ---------------------------------------------------------------------
# Maps (id(G), device, dtype) -> (weakref to G, dense incidence tensor).
_incidence_cache = {}


def incidence_dense(G, device=None, dtype=torch.float32):
    """
    Return the dense oriented incidence matrix of ``G`` as a torch tensor.

    The matrix is obtained from ``nx.incidence_matrix(G, oriented=True)``,
    converted with ``sparse_tensor_from_coo_matrix``, moved to ``device`` /
    ``dtype`` and densified.  The result is memoised per
    ``(graph object, device, dtype)`` so it is built once per run.

    The cache key uses ``id(G)``; because an id can be reused after the
    original graph has been garbage-collected, every entry also stores a weak
    reference to its graph.  A hit is only returned when that reference still
    points at ``G``, and the entry is evicted as soon as the graph dies.

    Parameters
    ----------
    G : networkx graph
        Graph whose incidence matrix is requested.  Must support weak
        references (ordinary Python class instances do).
    device : optional
        Passed through to ``Tensor.to``.
    dtype : torch.dtype
        Element type of the returned tensor (default ``torch.float32``).

    Returns
    -------
    torch.Tensor
        Dense incidence matrix.  The same tensor object is returned on
        repeated calls, so callers must not modify it in place.

    Notes
    -----
    The cache does not detect in-place mutation of ``G`` (e.g. edges added
    after the first call).
    """
    import weakref  # function-scope import keeps this block self-contained

    key = (id(G), device, dtype)
    cached = _incidence_cache.get(key)
    if cached is not None:
        owner_ref, dense = cached
        # Guard against id() reuse: only trust the entry if it was built for
        # this very object.
        if owner_ref() is G:
            return dense

    Bsp = nx.incidence_matrix(G, oriented=True)  # SciPy sparse
    dense = sparse_tensor_from_coo_matrix(Bsp).to(device, dtype=dtype).to_dense()

    def _evict(_ref, _key=key):
        # Runs when G is collected; drop the now-orphaned tensor.
        _incidence_cache.pop(_key, None)

    _incidence_cache[key] = (weakref.ref(G, _evict), dense)
    return dense


# =====================================================================
#  GroupActionFlowEstimator
# =====================================================================
class GroupActionFlowEstimator(nn.Module):
    """
    Low-level solver that *learns* how to impute missing flows on a single graph.

    Parameters
    ----------
    G : networkx.DiGraph
        The topology.  Edge order is fixed for the entire lifetime of the object.
    dgl_G : dgl.DGLGraph
        Same graph as G, but stored in DGL format for fast message passing.
        Its ``ndata['feat']`` is fed to ``net``.
    priors : dict
        Any extra per-edge feature dict (stored, but unused in this class).
    lamb : float
        Initial value of Tikhonov weight λ (>0).  Optimised in log-space.
    net : PhysicsAwareGATv2
        Edge encoder that produces (diag_out, low_out) per edge.
    n_folds : int
        #inner CV folds for outer-loop training.
    edge_map : dict[(u,v) → int]
        Map each edge to “column id” in vectors of length m.
    outer_n_iter : int
        Max. #outer iterations (epochs).
    outer_lr : float
        Learning-rate for outer optimiser.
    alpha : float
        Not used directly here; kept for backwards compatibility.
    n_basis_actions : int
        k – #basis vectors kept from null space (default 32).  Reduced to the
        number of vectors actually produced when k exceeds the number of edges.
    nonneg : bool
        Stored only; not read by any method of this class.
    early_stop : int
        Window length of the early-stopping rule in ``train_model``; values
        <= 0 disable early stopping.
    injections : torch.Tensor or None
        Optional manual target divergence; when given it replaces the
        thresholded divergence in ``arbitrary_balanced_completion``.
    div_threshold : float
        Node divergences with absolute value <= this are targeted to zero
        when ``injections`` is None.

    Attributes set by training
    --------------------------
    index : mapping returned by ``lsq_matrix_flow`` for the final solve.
    x_star : torch.Tensor, solution of the final solve on all observed edges.
    """
    # -------------------------------------------------------------------------
    # constructor
    # -------------------------------------------------------------------------
    def __init__(
        self,
        G,
        dgl_G,
        priors,
        lamb,
        net,
        n_folds,
        edge_map,
        outer_n_iter,
        outer_lr,
        alpha=0.1,
        n_basis_actions=32,
        nonneg=False,
        early_stop=10,
        *,
        injections=None,
        div_threshold=0
    ):
        super().__init__()

        # graph & data
        self.G = G
        self.dgl_G = dgl_G
        self.priors = priors
        self.edge_map = edge_map
        self.n_folds = n_folds

        # make λ learnable (in log-space)
        init_val = float(lamb)
        self.log_lamb = nn.Parameter(
            torch.log(torch.tensor(init_val, dtype=torch.float32))
        )

        # network & optimizer settings
        self.net = net
        self.outer_n_iter = outer_n_iter
        self.outer_lr = outer_lr
        self.early_stop = early_stop
        self.nonneg = nonneg

        # group-action basis settings
        self.n_basis_actions = n_basis_actions
        self.alpha = alpha

        # optional manual injections or threshold
        self.c_manual = injections
        self.div_threshold = div_threshold

        # cached incidence matrix (plain attribute, not a buffer, so it is
        # placed on the target device explicitly)
        dev = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.B = incidence_dense(G, device=dev)

        # build a null-space basis U (k × m); may shrink self.n_basis_actions
        self.init_group_action_basis()

        # Learnable attention weights  w_i ∈ ℝ^{hid}  (k × hid)
        hid = 1 + getattr(net, 'low_rank_dim', 4)
        self.w_params = nn.Parameter(
            torch.randn(self.n_basis_actions, hid, device=self.B.device)
        )

        # Move every registered parameter / sub-module to the GPU.  Parameters
        # must not be re-assigned with the result of ``.cuda()``: nn.Module
        # only accepts nn.Parameter (or None) under a registered name.
        if torch.cuda.is_available():
            self.cuda()

        # placeholders for plotting/debug
        self.last_attention_weights = None
        self.last_delta0 = None
        self.last_x_missing = None

    # -------------------------------------------------------------------------
    # S2 – build orthonormal basis  U⊤ ∈ ℝ^{k×m},  each row = basis vector u⁽i⁾
    #      • divergence-free  (B u = 0)
    #      • spans ker B  but truncated to k vectors
    # -------------------------------------------------------------------------
    def init_group_action_basis(self):
        """
        Initialise ``self.basis_actions`` (k × m) from random vectors projected
        onto ker(B) and orthonormalised with a reduced QR factorisation.

        Reads ``self.B``, ``self.G`` and ``self.n_basis_actions``.  After the
        call ``self.n_basis_actions`` equals the number of rows actually
        stored, which is smaller than requested when k > m (reduced QR cannot
        return more than m columns).

        NOTE(review): when k exceeds dim ker(B) the projected matrix is
        rank-deficient and the extra QR columns are not guaranteed to be
        divergence-free.
        """
        device = self.B.device
        m = self.G.number_of_edges()
        k = self.n_basis_actions

        # random seed matrix
        R = torch.randn(m, k, device=device)

        # projector onto ker(B):  I − Bᵀ (B Bᵀ)⁺ B
        B = self.B  # (n × m)
        P_bal = torch.eye(m, device=device) - B.t() @ torch.linalg.pinv(B @ B.t()) @ B

        # orthonormalise columns
        U, _ = torch.linalg.qr(P_bal @ R)  # (m × min(m, k))
        # store as rows = basis vectors
        self.basis_actions = nn.Parameter(U.T)
        # keep k consistent with what QR actually returned
        self.n_basis_actions = self.basis_actions.shape[0]

    # -------------------------------------------------------------------------
    # S1 – produce *one* balanced completion that preserves observed flows
    #      `partial_flows`  = hat_f  with zeros on missing edges
    # -------------------------------------------------------------------------
    def arbitrary_balanced_completion(self, partial_flows, mask_miss):
        """
        Return Δ⁰  such that B (partial_flows + Δ⁰) ≈ c   and Δ⁰ vanishes on obs.
        We choose the *minimum-ℓ₂-norm* (least-squares) solution over the
        missing coordinates.

        Parameters
        ----------
        partial_flows : torch.Tensor, shape (m,) or (m, 1)
            Observed flows, zero on missing edges.
        mask_miss : torch.BoolTensor, shape (m,)
            True on missing edges.

        Returns
        -------
        torch.Tensor
            Same shape as ``partial_flows``; non-zero only where ``mask_miss``.

        Notes
        -----
        The target divergence c is ``self.c_manual`` when set; otherwise it is
        the current divergence with entries of magnitude <= ``div_threshold``
        zeroed out.
        """
        B = self.B.to(partial_flows)
        div_p = B @ partial_flows

        if self.c_manual is not None:
            target = self.c_manual.to(partial_flows)
        else:
            mask = div_p.abs() > self.div_threshold
            target = div_p * mask.float()

        rhs = target - div_p
        B_miss = B[:, mask_miss]
        BBt = B_miss @ B_miss.t()
        # B_missᵀ (B_miss B_missᵀ)⁺ rhs  is the pseudo-inverse solution
        delta_miss = B_miss.t() @ torch.linalg.pinv(BBt) @ rhs

        comp = torch.zeros_like(partial_flows)
        comp[mask_miss] = delta_miss
        return comp

    # -------------------------------------------------------------------------
    # S3 – compute divergence-free correction  Δ = U α   with attention weights
    # -------------------------------------------------------------------------
    def compute_group_action(self, partial_flows, mask_miss, edge_features=None):
        """
        1.  Run the GAT-v2 encoder once → per-edge embedding  h_e  ∈ ℝ^{hid}.
        2.  For every basis vector u⁽i⁾, compute compatibility scores
              q_{e,i} = (w_iᵀ h_e) · |u⁽i⁾_e|
        3.  Average over missing edges → s_i,  softmax → α_i.
        4.  Return Δ = U α  restricted to missing edges.

        Parameters
        ----------
        partial_flows : torch.Tensor
            Accepted for API symmetry; not read by this method.
        mask_miss : torch.BoolTensor, shape (m,)
            True on missing edges.
        edge_features : optional
            Forwarded unchanged to ``self.net``.

        Returns
        -------
        torch.Tensor, shape (m_miss, 1)

        Side effects
        ------------
        Stores the softmax weights α (as a NumPy array of length k) in
        ``self.last_attention_weights``.
        """
        # Edge embeddings  (diag_out, low_out)  → concat (m, hid)
        all_feat = self.dgl_G.ndata['feat']
        diag_out, low_out, _ = self.net(
            self.dgl_G, all_feat, edge_features, return_attention=True
        )

        # assemble edge embeddings and keep only the missing edges
        edge_emb = torch.cat([diag_out.unsqueeze(1), low_out], dim=1)
        emb_miss = edge_emb[mask_miss]  # (m_miss, hid)

        # 2) compatibility scores: q_{e,i} = (w_i^T H_e) * |u^{(i)}_e|
        scores_raw = emb_miss @ self.w_params.t()  # (m_miss, k)

        # restrict basis to missing columns
        basis_sub = self.basis_actions[:, mask_miss]  # (k, m_miss)
        basis_abs = torch.abs(basis_sub.T)  # (m_miss, k)

        Q = scores_raw * basis_abs  # (m_miss, k)

        # 3) aggregate and softmax
        s = Q.mean(dim=0, keepdim=True)  # (1, k)
        alpha = F.softmax(s, dim=1)  # (1, k)
        self.last_attention_weights = alpha.view(-1).detach().cpu().numpy()

        # 4) form correction Δ = αᵀ U on the missing columns
        delta = (alpha @ basis_sub).view(-1, 1)  # (m_miss, 1)
        return delta

    # -------------------------------------------------------------------------
    # helper: apply Δ  +  Δ⁰  to get a complete flow vector of length m
    # -------------------------------------------------------------------------
    def apply_group_action(self, partial_flows, mask_miss, action):
        """
        Return ``partial_flows + Δ⁰ + Δ`` where Δ⁰ is the balanced completion
        and Δ is ``action`` scattered onto the missing edges.

        Parameters
        ----------
        partial_flows : torch.Tensor, shape (m,) or (m, 1)
        mask_miss : torch.BoolTensor, shape (m,)
        action : torch.Tensor with m_miss elements, e.g. (m_miss,) or (m_miss, 1)

        Returns
        -------
        torch.Tensor with the shape of ``partial_flows``.
        """
        comp = self.arbitrary_balanced_completion(partial_flows, mask_miss)
        full_act = torch.zeros_like(partial_flows)
        # Reshape to the indexed slot so both (m,) and (m, 1) layouts work;
        # a plain squeeze cannot be broadcast into an (m_miss, 1) slot.
        full_act[mask_miss] = action.reshape(full_act[mask_miss].shape)
        return partial_flows + comp + full_act

    # -------------------------------------------------------------------------
    # helper: dict-of-flows  →  tensors consumed by predict_flows
    # -------------------------------------------------------------------------
    def _assemble_solver_inputs(self, observed, idx_map, device):
        """Build ``hat_f`` (observed flows, zeros elsewhere) and the index
        vector mapping each solver column in ``idx_map`` to its edge id."""
        hat_f = torch.zeros(self.G.number_of_edges(), device=device)
        for edge, val in observed.items():
            hat_f[self.edge_map[edge]] = val

        missing_idx = torch.empty(len(idx_map), dtype=torch.long, device=device)
        for edge, col in idx_map.items():
            missing_idx[col] = self.edge_map[edge]
        return hat_f, missing_idx

    # -------------------------------------------------------------------------
    # Outer-loop bilevel training  (implicit differentiation inside predict_flows)
    # -------------------------------------------------------------------------
    def train_model(self, train_flows, valid_flows, verbose=False):
        """
        Hyper-parameter learning loop over K folds.
        Learns:
            • GAT-v2 weights θ
            • null-space basis  U  (fine-tuned)
            • attention parameters  {w_i}
            • Tikhonov weight  λ

        Each epoch sums the per-fold MSE, takes one Adam step (gradient norm
        clipped to 1.0) and records the loss.  The parameters that produced
        the lowest epoch loss are snapshotted and restored before the final
        solve on all observed edges, whose result is stored in ``self.x_star``
        (with its index map in ``self.index``).

        Parameters
        ----------
        train_flows, valid_flows : dict[(u,v) → float]
            Observed flows; merged (valid overrides train on key clashes).
        verbose : bool
            Print per-epoch losses and the reason a fold was skipped.

        Raises
        ------
        RuntimeError
            If in some epoch no fold yields a loss (all folds empty or failed).
        """
        # Cross-validation splits on *edge indices*
        int_folds = generate_folds({**train_flows, **valid_flows}, self.n_folds)

        # include basis, log_lamb, w_params in optimizer
        params = (
            list(self.net.parameters())
            + [self.basis_actions]
            + [self.log_lamb]
            + [self.w_params]
        )
        self.optimizer = torch.optim.Adam(params, lr=self.outer_lr)

        dev0 = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        mse = nn.MSELoss()
        train_losses = []
        best_state = None
        best_loss = float('inf')

        for epoch in range(self.outer_n_iter):
            self.optimizer.zero_grad()
            outer_loss = torch.zeros(1, device=dev0)
            folds_used = 0

            for int_train, int_test in int_folds:
                if not int_train:
                    continue
                try:
                    # prepare least-squares system
                    A, b, idx_map = lsq_matrix_flow(self.G, int_train)
                    hat_f, missing_idx = self._assemble_solver_inputs(
                        int_train, idx_map, dev0
                    )

                    # end-to-end solve
                    x_hat = self.predict_flows(A, b, hat_f, missing_idx)

                    # compute fold loss
                    y_val, mapp = get_fold_flow_data(self.G, int_train, int_test)
                    loss = mse(torch.sparse.mm(mapp, x_hat), y_val)
                    if torch.isnan(loss):
                        raise RuntimeError("NaN loss")
                    outer_loss += loss
                    folds_used += 1

                except Exception as e:
                    # deliberate best-effort: a failing fold is skipped
                    if verbose:
                        print("[skip fold]", e)
                    continue

            if folds_used == 0:
                # Without any fold the loss has no autograd graph and
                # backward() would fail with an opaque message.
                raise RuntimeError(
                    "No cross-validation fold produced a loss; "
                    "re-run with verbose=True to see why folds were skipped"
                )

            if torch.isnan(outer_loss):
                print("NaN outer loss – aborting")
                break

            loss_val = outer_loss.item()

            # checkpoint: the loss was computed with the *current* parameters,
            # so snapshot them before the optimiser step.  Tensors are cloned
            # because state_dict() returns references to the live weights.
            if best_state is None or loss_val < best_loss:
                best_loss = loss_val
                best_state = {
                    'net': {
                        name: t.detach().clone()
                        for name, t in self.net.state_dict().items()
                    },
                    'basis': self.basis_actions.detach().clone(),
                    'log_lamb': self.log_lamb.detach().clone(),
                    'w_params': self.w_params.detach().clone()
                }

            outer_loss.backward()
            torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
            self.optimizer.step()
            train_losses.append(loss_val)

            if verbose:
                print(f"[epoch {epoch:02d}] loss={loss_val:.4f}")

            # early stopping: latest loss above the mean of the last
            # `early_stop` losses (non-positive window = disabled)
            if (
                self.early_stop > 0
                and epoch > self.early_stop
                and train_losses[-1] > sum(train_losses[-self.early_stop:]) / self.early_stop
            ):
                if verbose:
                    print("Early stopping")
                break

        # roll back to the best parameters seen during training
        if best_state is not None:
            self.net.load_state_dict(best_state['net'])
            with torch.no_grad():
                self.basis_actions.copy_(best_state['basis'])
                self.log_lamb.copy_(best_state['log_lamb'])
                self.w_params.copy_(best_state['w_params'])

        # final solve on all observed edges
        observed = {**train_flows, **valid_flows}
        A, b, idx_map = lsq_matrix_flow(self.G, observed)
        self.index = idx_map
        hat_f, missing_idx = self._assemble_solver_inputs(observed, idx_map, dev0)

        self.x_star = self.predict_flows(A, b, hat_f, missing_idx)

    # -------------------------------------------------------------------------
    # S4 – implicit Tikhonov solve  (differentiable)
    # -------------------------------------------------------------------------
    def predict_flows(self, A, b, partial_flows, missing_idx):
        """
        Solve  (AᵀA + λI) x = Aᵀb + λ δ⁽⁰⁾  for the missing flows, where
        δ⁽⁰⁾ = balanced completion + learned group action on the missing edges
        and λ = exp(``self.log_lamb``).

        Parameters
        ----------
        A : torch.Tensor or sparse matrix
            System matrix; non-tensors are converted with
            ``sparse_tensor_from_coo_matrix`` and densified.
        b : array-like
            Right-hand side; converted to a float32 column vector.
        partial_flows : torch.Tensor, shape (m,)
            Observed flows with zeros on missing edges.
        missing_idx : torch.LongTensor, shape (m_miss,)
            Edge id of each unknown, in solver-column order.

        Returns
        -------
        torch.Tensor, shape (m_miss, 1)

        Side effects
        ------------
        Updates ``self.last_delta0`` and ``self.last_x_missing`` (NumPy copies)
        and, via ``compute_group_action``, ``self.last_attention_weights``.
        """
        device = partial_flows.device
        m_miss = missing_idx.shape[0]

        mask_full = torch.zeros(partial_flows.shape[0], dtype=torch.bool, device=device)
        mask_full[missing_idx] = True
        comp_full = self.arbitrary_balanced_completion(partial_flows, mask_full)

        action = self.compute_group_action(partial_flows.view(-1, 1), mask_full)
        action = action.view(-1)

        delta0 = comp_full[mask_full] + action
        self.last_delta0 = delta0.detach().cpu().numpy()

        lamb = torch.exp(self.log_lamb)

        if not torch.is_tensor(A):
            A_mat = sparse_tensor_from_coo_matrix(A).to_dense().to(device)
        else:
            A_mat = A.to(device)
        b_vec = torch.as_tensor(b, dtype=torch.float32, device=device).view(-1, 1)

        # Tikhonov: solve (AᵀA + λI)x = Aᵀb + λ δ⁽⁰⁾
        ATA = A_mat.t() @ A_mat
        M = ATA + lamb * torch.eye(m_miss, device=device)
        rhs = A_mat.t() @ b_vec
        rhs = rhs + lamb * delta0.view(-1, 1)
        x_missing = torch.linalg.solve(M, rhs)

        self.last_x_missing = x_missing.view(-1).detach().cpu().numpy()
        return x_missing

    # -----------------------------------------------------------------
    #  Plotting utilities
    # -----------------------------------------------------------------
    def plot_attention_histogram(self, bins: int = 50, save_path: str = "attention_weights_bikec.txt"):
        """
        Show a histogram of the most recent attention weights and write them
        to ``save_path`` (plain text, one value per line, with a header line).

        Raises
        ------
        RuntimeError
            If no attention weights have been recorded yet.
        """
        if self.last_attention_weights is None:
            raise RuntimeError(
                "No attention weights recorded yet; run train_model() or "
                "predict_flows() first"
            )

        # --- plotting ---
        plt.figure(figsize=(6, 4))
        plt.hist(self.last_attention_weights, bins=bins, color='skyblue', edgecolor='black', linewidth=0.4)
        plt.title("Attention Weights Histogram")
        plt.xlabel("Attention Weight")
        plt.ylabel("Frequency")
        plt.tight_layout()
        plt.show()

        # --- saving to TXT ---
        np.savetxt(
            save_path,
            self.last_attention_weights,
            fmt="%.6e",
            header="Attention weights, one per line",
            comments=""
        )
        print(f"Saved attention weights to {save_path}")


# =============================================================================
#  Thin public wrapper – hides all the plumbing from end-users
# =============================================================================
class GroupActionFlow(nn.Module):
    """
    High-level wrapper that exposes .train() and .predict()

    Parameters
    ----------
    G : networkx graph
        Topology handed to the estimator.
    features : dict
        Feature vectors; the length of the first value's leading axis is
        used as the encoder's input dimension.
    params : dict
        Hyper-parameters.  Required keys: 'n_hidden', 'outer_n_iter',
        'outer_lr', 'early_stop', 'priors', 'n_folds'.  Optional keys (with
        defaults): 'nonneg' (False), 'num_heads' (4), 'low_rank_dim' (4),
        'lambda' (0.01), 'alpha' (0.1), 'n_basis_actions' (32),
        'injections' (None), 'div_threshold' (0).
    verbose : bool
        Forwarded to the estimator's training loop.
    """

    def __init__(self, G, features, params, verbose=False):
        super().__init__()
        self.G = G
        self.features = features
        self.params = params
        self.verbose = verbose

    def train(self, train_flows=True, valid_flows=None):
        """
        Fit the model on observed flows.

        This method shares its name with ``nn.Module.train(mode)``, which
        PyTorch itself calls (e.g. from ``.eval()``).  To keep those calls
        working, a lone boolean argument is forwarded to
        ``nn.Module.train`` and the module is returned; any other call
        performs the actual fitting and returns None.

        Parameters
        ----------
        train_flows : dict[(u,v) → float]  (or bool, see above)
            Observed training flows.
        valid_flows : dict[(u,v) → float] or None
            Observed validation flows; None is treated as an empty dict.
        """
        if valid_flows is None and isinstance(train_flows, bool):
            # nn.Module mode switch, not a fitting request
            return super().train(train_flows)
        if valid_flows is None:
            valid_flows = {}

        from gcn import create_dgl_graph

        # read once so graph construction and solver agree on the setting
        nonneg = self.params.get('nonneg', False)

        self.G_dgl, self.edge_map = create_dgl_graph(
            self.G,
            self.features,
            directed=nonneg
        )
        in_dim = self.features[next(iter(self.features))].shape[0]

        self.gat_net = PhysicsAwareGATv2(
            in_feats=in_dim,
            hidden_size=self.params['n_hidden'],
            n_iter=self.params['outer_n_iter'],
            lr=self.params['outer_lr'],
            num_heads=self.params.get('num_heads', 4),
            low_rank_dim=self.params.get('low_rank_dim', 4),
            early_stop=self.params['early_stop'],
            output_activation=torch.sigmoid,
            use_edge_features=True
        )

        self.ga_solver = GroupActionFlowEstimator(
            G=self.G,
            dgl_G=self.G_dgl,
            priors=self.params['priors'],
            lamb=self.params.get('lambda', 0.01),
            net=self.gat_net,
            n_folds=self.params['n_folds'],
            edge_map=self.edge_map,
            outer_n_iter=self.params['outer_n_iter'],
            outer_lr=self.params['outer_lr'],
            alpha=self.params.get('alpha', 0.1),
            n_basis_actions=self.params.get('n_basis_actions', 32),
            nonneg=nonneg,
            early_stop=self.params['early_stop'],
            injections=self.params.get('injections'),
            div_threshold=self.params.get('div_threshold', 0)
        )

        self.ga_solver.train_model(train_flows, valid_flows, verbose=self.verbose)

    # -------------------------------------------------------------------------
    # Look up predictions for a set of edges
    # -------------------------------------------------------------------------
    def predict(self, test_flows):
        """
        Return predictions taken from the solution stored during training.

        No new solve is performed: the estimator's ``index`` and ``x_star``
        (set by ``train_model``) are passed, together with ``test_flows``, to
        ``get_dict_flows_from_tensor``.

        Parameters
        ----------
        test_flows : dict[(u,v) → float]
            Edges of interest — presumably the keys select which entries of
            the stored solution are returned; verify against
            ``get_dict_flows_from_tensor``.

        Returns
        -------
        Whatever ``get_dict_flows_from_tensor`` returns (presumably a
        dict[(u,v) → float]).

        Side effects
        ------------
        Always shows the attention-weight histogram and writes the weights to
        the estimator's default text file.
        """
        preds = get_dict_flows_from_tensor(
            self.ga_solver.index,
            self.ga_solver.x_star,
            test_flows
        )

        # plot (and save) attention weights
        self.ga_solver.plot_attention_histogram()
        return preds
